package com.boot.socket5.server.compont;

import lombok.extern.slf4j.Slf4j;

import java.io.*;
import java.net.*;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.concurrent.CountDownLatch;


@Slf4j
public class SocksServerOneThread implements Runnable {

    /** Version byte that starts a SOCKS4 / SOCKS4a request. */
    private static final int SOCKS4_VERSION = 0x04;
    /** Version byte that starts a SOCKS5 request (RFC 1928). */
    private static final int SOCKS5_VERSION = 0x05;

    /** SOCKS4 reply code: request granted. */
    private static final byte SOCKS4_GRANTED = 90;
    /** SOCKS4 reply code: request rejected or failed. */
    private static final byte SOCKS4_REJECTED = 91;

    /** Command code shared by SOCKS4 and SOCKS5: CONNECT. */
    private static final byte CMD_CONNECT = 0x01;
    /** Command code shared by SOCKS4 and SOCKS5: BIND. */
    private static final byte CMD_BIND = 0x02;

    /** SOCKS5 authentication method: no authentication required. */
    private static final byte METHOD_NO_AUTH = 0x00;
    /** SOCKS5 authentication method: username/password (RFC 1929). */
    private static final byte METHOD_USER_PASS = 0x02;
    /** SOCKS5 method-selection reply meaning "none of the offered methods is acceptable". */
    private static final byte METHOD_NO_ACCEPTABLE = (byte) 0xFF;

    /** Version byte of the RFC 1929 username/password sub-negotiation (request and reply). */
    private static final byte AUTH_VERSION = 0x01;

    /** SOCKS5 reply codes (RFC 1928, section 6). */
    private static final byte REP_SUCCEEDED = 0x00;
    private static final byte REP_GENERAL_FAILURE = 0x01;
    private static final byte REP_HOST_UNREACHABLE = 0x04;
    private static final byte REP_CONNECTION_REFUSED = 0x05;
    private static final byte REP_COMMAND_NOT_SUPPORTED = 0x07;
    private static final byte REP_ADDRESS_TYPE_NOT_SUPPORTED = 0x08;

    /** Upper bound for NUL-terminated SOCKS4 fields (USERID, SOCKS4a host name) so a client cannot make us buffer forever. */
    private static final int MAX_NUL_TERMINATED_FIELD = 255;

    /** How long a SOCKS5 BIND waits for the expected incoming connection before giving up. */
    private static final int BIND_ACCEPT_TIMEOUT_MS = 120_000;

    /**
     * The accepted client connection being proxied.
     */
    private final Socket socket;
    /**
     * Whether SOCKS4 / SOCKS4a requests are served.
     */
    private final boolean openSock4;
    /**
     * Whether SOCKS5 requests are served.
     */
    private final boolean openSock5;
    /**
     * Expected SOCKS5 username; compared only when {@link #socksNeekLogin} is true.
     */
    private final String user;
    /**
     * Expected SOCKS5 password; compared only when {@link #socksNeekLogin} is true.
     */
    private final String pwd;
    /**
     * Whether SOCKS5 clients must authenticate with username/password (RFC 1929).
     */
    private final boolean socksNeekLogin;

    /**
     * @param socket         the accepted client connection to serve
     * @param openSock4      whether SOCKS4 / SOCKS4a requests are served
     * @param openSock5      whether SOCKS5 requests are served
     * @param user           expected SOCKS5 username
     * @param pwd            expected SOCKS5 password
     * @param socksNeekLogin whether SOCKS5 clients must authenticate with {@code user}/{@code pwd}
     */
    public SocksServerOneThread(Socket socket, boolean openSock4, boolean openSock5, String user, String pwd, boolean socksNeekLogin) {
        this.socket = socket;
        this.openSock4 = openSock4;
        this.openSock5 = openSock5;
        this.user = user;
        this.pwd = pwd;
        this.socksNeekLogin = socksNeekLogin;
    }

    /**
     * Serves one client connection.
     *
     * <p>Reads the protocol version byte and runs the matching SOCKS4/4a or SOCKS5 handshake. It then relays
     * traffic in both directions until the session ends. Finally it closes both the client socket and the
     * upstream socket, so no thread or file descriptor outlives the session.
     */
    @Override
    public void run() {
        // Client address, used only for logging (String.valueOf tolerates an unconnected socket).
        String addr = String.valueOf(socket.getRemoteSocketAddress());
        log.info("process one socket : {}", addr);

        Socket proxySocket = null;
        try {
            InputStream clientIn = socket.getInputStream();
            OutputStream clientOut = socket.getOutputStream();

            // InputStream.read() returns the byte itself (0-255) or -1 at EOF, not a byte count.
            int protocol = clientIn.read();
            if (protocol == -1) {
                log.info("client {} closed before sending a request", addr);
                return;
            }
            if (openSock4 && protocol == SOCKS4_VERSION) {
                proxySocket = sock4Check(clientIn, clientOut);
            } else if (openSock5 && protocol == SOCKS5_VERSION) {
                proxySocket = sock5Check(clientIn, clientOut);
            } else {
                log.info("not socks proxy : {}  openSock4[{}] openSock5[{}]", protocol, openSock4, openSock5);
            }

            if (proxySocket != null) {
                relay(proxySocket);
            }
        } catch (IOException e) {
            log.warn("socks session {} failed", addr, e);
        } finally {
            closeIo(proxySocket);
            closeIo(socket);
        }
    }

    /**
     * Relays bytes between the client and {@code proxySocket} in both directions.
     *
     * <p>Blocks until both directions have finished. Each direction passes its EOF on as a half-close, so
     * a request/response exchange is not cut short when one side finishes sending first.
     *
     * @param proxySocket the upstream socket returned by the handshake
     */
    private void relay(Socket proxySocket) {
        CountDownLatch done = new CountDownLatch(2);
        pump(socket, proxySocket, done, "socks-relay-up");
        pump(proxySocket, socket, done, "socks-relay-down");
        try {
            done.await();
        } catch (InterruptedException e) {
            // Preserve the interrupt; run() closes both sockets, which also stops the pump threads.
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Copies everything readable from {@code from} to {@code to} on a new thread.
     *
     * <p>On a clean EOF it shuts down the write half of {@code to}, so the peer sees end-of-stream while
     * the opposite direction keeps flowing. On any I/O error it closes both sockets, which also unblocks
     * the opposite pump.
     *
     * @param from       socket to read from
     * @param to         socket to write to
     * @param done       counted down exactly once when this direction finishes
     * @param threadName name of the worker thread, for diagnostics
     */
    private static void pump(final Socket from, final Socket to, final CountDownLatch done, String threadName) {
        new Thread(new Runnable() {
            @Override
            public void run() {
                byte[] buffer = new byte[8192];
                try {
                    InputStream in = from.getInputStream();
                    OutputStream out = to.getOutputStream();
                    int n;
                    while ((n = in.read(buffer)) != -1) {
                        out.write(buffer, 0, n);
                        out.flush();
                    }
                    to.shutdownOutput();
                } catch (IOException e) {
                    // Usually one side (or run()) closed a socket; tear both down so neither direction hangs.
                    closeIo(from);
                    closeIo(to);
                } finally {
                    done.countDown();
                }
            }
        }, threadName).start();
    }

    /**
     * Handles a SOCKS4 / SOCKS4a request. The leading VN byte has already been consumed by {@link #run()}.
     *
     * <p>Only CONNECT is supported. Any other command is answered with "rejected" (91).
     *
     * @param in  client input stream, positioned just after the VN byte
     * @param out client output stream that receives the 8-byte SOCKS4 reply
     * @return the connected upstream socket, or {@code null} if the request was rejected or the connect failed
     * @throws IOException if the client stream ends early or the reply cannot be written
     */
    private Socket sock4Check(InputStream in, OutputStream out) throws IOException {
        // Request after VN: |CD 1|DSTPORT 2|DSTIP 4|USERID ...|NULL 1|
        byte cmd = (byte) readUnsignedByte(in);
        int port = readPort(in);
        byte[] ip = readFully(in, 4);
        readNulTerminated(in); // USERID: consumed in full but not checked
        String host;
        if (ip[0] == 0 && ip[1] == 0 && ip[2] == 0 && ip[3] != 0) {
            // SOCKS4a: DSTIP 0.0.0.x (x != 0) means a NUL-terminated host name follows USERID.
            host = readNulTerminated(in);
        } else {
            host = InetAddress.getByAddress(ip).getHostAddress();
        }
        log.info("connect [{}] {}:{}", cmd, host, port);

        Socket proxySocket = null;
        byte status = SOCKS4_REJECTED;
        if (cmd == CMD_CONNECT) {
            try {
                proxySocket = new Socket(host, port);
                status = SOCKS4_GRANTED;
            } catch (IOException | IllegalArgumentException | SecurityException e) {
                log.info("connect exception  {}:{}", host, port);
            }
        } else {
            log.info("unsupported socks4 command {}", cmd);
        }

        // Reply: |VN 1 (=0)|CD 1|DSTPORT 2|DSTIP 4|; port and address are left zero.
        byte[] reply = new byte[8];
        reply[1] = status;
        try {
            out.write(reply);
            out.flush();
        } catch (IOException e) {
            closeIo(proxySocket);
            throw e;
        }
        return proxySocket;
    }

    /**
     * Handles a SOCKS5 session up to the point where data can be relayed. The leading VER byte has already
     * been consumed by {@link #run()}.
     *
     * @param in  client input stream, positioned just after the VER byte
     * @param out client output stream for the method-selection, auth and request replies
     * @return the upstream socket (CONNECT) or the accepted peer socket (BIND); {@code null} if the
     *         session was refused or failed, in which case the appropriate reply has already been sent
     * @throws IOException if the client stream ends early or a reply cannot be written
     */
    private Socket sock5Check(InputStream in, OutputStream out) throws IOException {
        if (!negotiateMethod(in, out)) {
            return null;
        }

        // Request: |VER 1|CMD 1|RSV 1|ATYP 1|DST.ADDR var|DST.PORT 2|
        byte[] header = readFully(in, 4);
        log.info("proxy header >>  {}", Arrays.toString(header));
        if (header[0] != SOCKS5_VERSION) {
            writeSocks5Reply(out, REP_GENERAL_FAILURE, null, 0);
            return null;
        }
        byte cmd = header[1];
        String host = getHost(header[3], in);
        if (host == null) {
            // Unknown ATYP: the address length is unknown, so the rest of the request cannot be parsed.
            // (Passing null to new Socket(...) would silently connect to localhost.)
            writeSocks5Reply(out, REP_ADDRESS_TYPE_NOT_SUPPORTED, null, 0);
            return null;
        }
        int port = readPort(in);
        log.info("connect {}:{}", host, port);

        if (cmd == CMD_CONNECT) {
            return socks5Connect(out, host, port);
        }
        if (cmd == CMD_BIND) {
            return socks5Bind(out);
        }
        writeSocks5Reply(out, REP_COMMAND_NOT_SUPPORTED, null, 0);
        return null;
    }

    /**
     * Runs the SOCKS5 method-selection handshake and, when login is required, the RFC 1929 sub-negotiation.
     *
     * <p>Selects username/password when {@link #socksNeekLogin} is set, otherwise "no authentication".
     * The selection only succeeds if the client actually offered that method.
     *
     * @return {@code true} if the client may go on to send its request
     */
    private boolean negotiateMethod(InputStream in, OutputStream out) throws IOException {
        // Greeting after VER: |NMETHODS 1|METHODS NMETHODS| -- every offered method must be consumed.
        int methodCount = readUnsignedByte(in);
        byte[] methods = readFully(in, methodCount);
        byte wanted = socksNeekLogin ? METHOD_USER_PASS : METHOD_NO_AUTH;
        if (!contains(methods, wanted)) {
            out.write(new byte[] { (byte) SOCKS5_VERSION, METHOD_NO_ACCEPTABLE });
            out.flush();
            if (socksNeekLogin) {
                log.info("socks server need login,but no login info .");
            } else {
                log.info("client did not offer the no-auth method: {}", Arrays.toString(methods));
            }
            return false;
        }
        out.write(new byte[] { (byte) SOCKS5_VERSION, wanted });
        out.flush();
        return wanted != METHOD_USER_PASS || authenticate(in, out);
    }

    /**
     * Performs the RFC 1929 username/password sub-negotiation and sends its reply.
     *
     * <p>The reply is {@code |0x01|STATUS|}, with status 0 on success. Leading and trailing whitespace of
     * the received credentials is ignored, as before.
     *
     * @return {@code true} if the credentials match {@link #user} / {@link #pwd}
     */
    private boolean authenticate(InputStream in, OutputStream out) throws IOException {
        // |VER 1 (=0x01)|ULEN 1|UNAME ULEN|PLEN 1|PASSWD PLEN|
        int version = readUnsignedByte(in);
        if (version != AUTH_VERSION) {
            log.info("unsupported auth sub-negotiation version {}", version);
            out.write(new byte[] { AUTH_VERSION, 0x01 });
            out.flush();
            return false;
        }
        String name = new String(readFully(in, readUnsignedByte(in)), StandardCharsets.UTF_8);
        String password = new String(readFully(in, readUnsignedByte(in)), StandardCharsets.UTF_8);
        boolean ok = name.trim().equals(this.user) && password.trim().equals(this.pwd);
        out.write(new byte[] { AUTH_VERSION, ok ? (byte) 0x00 : (byte) 0x01 });
        out.flush();
        if (ok) {
            log.info("{} login success !", name);
        } else {
            log.info("{} login faild !", name);
        }
        return ok;
    }

    /**
     * Opens the upstream connection for a SOCKS5 CONNECT and sends the reply.
     *
     * @return the connected socket, or {@code null} after sending a failure reply
     */
    private Socket socks5Connect(OutputStream out, String host, int port) throws IOException {
        Socket target;
        try {
            target = new Socket(host, port);
        } catch (UnknownHostException e) {
            log.info("connect exception  {}:{}", host, port);
            writeSocks5Reply(out, REP_HOST_UNREACHABLE, null, 0);
            return null;
        } catch (IOException | IllegalArgumentException | SecurityException e) {
            log.info("connect exception  {}:{}", host, port);
            writeSocks5Reply(out, REP_CONNECTION_REFUSED, null, 0);
            return null;
        }
        try {
            // BND.ADDR/BND.PORT: the local endpoint the proxy used to reach the target.
            writeSocks5Reply(out, REP_SUCCEEDED, target.getLocalAddress(), target.getLocalPort());
        } catch (IOException e) {
            closeIo(target);
            throw e;
        }
        return target;
    }

    /**
     * Serves a SOCKS5 BIND, which per RFC 1928 sends two replies.
     *
     * <p>The first reply carries the listening endpoint. The second carries the peer that connected, or
     * a failure if none connected within {@link #BIND_ACCEPT_TIMEOUT_MS}.
     *
     * @return the accepted peer socket, or {@code null} after sending a failure reply
     */
    private Socket socks5Bind(OutputStream out) throws IOException {
        ServerSocket listener;
        try {
            // Ephemeral port: DST.PORT in a BIND request names the expected peer, not a port to listen on.
            listener = new ServerSocket(0);
            listener.setSoTimeout(BIND_ACCEPT_TIMEOUT_MS);
        } catch (IOException e) {
            writeSocks5Reply(out, REP_GENERAL_FAILURE, null, 0);
            return null;
        }
        try {
            // The listener is bound to the wildcard address, so advertise the interface the client reached us on.
            writeSocks5Reply(out, REP_SUCCEEDED, socket.getLocalAddress(), listener.getLocalPort());
            Socket incoming;
            try {
                incoming = listener.accept();
            } catch (IOException e) {
                writeSocks5Reply(out, REP_GENERAL_FAILURE, null, 0);
                return null;
            }
            try {
                writeSocks5Reply(out, REP_SUCCEEDED, incoming.getInetAddress(), incoming.getPort());
            } catch (IOException e) {
                closeIo(incoming);
                throw e;
            }
            return incoming;
        } finally {
            closeIo(listener);
        }
    }

    /**
     * Writes a SOCKS5 reply {@code |VER|REP|RSV|ATYP|BND.ADDR|BND.PORT|}.
     *
     * <p>ATYP is chosen from the address length (4 bytes for IPv4, 16 for IPv6).
     *
     * @param bindAddr address to report, or {@code null} for 0.0.0.0
     * @param bindPort port to report (0-65535)
     */
    private static void writeSocks5Reply(OutputStream out, byte rep, InetAddress bindAddr, int bindPort) throws IOException {
        byte[] address = bindAddr == null ? new byte[4] : bindAddr.getAddress();
        ByteBuffer reply = ByteBuffer.allocate(6 + address.length);
        reply.put((byte) SOCKS5_VERSION)
                .put(rep)
                .put((byte) 0x00)
                .put(address.length == 16 ? (byte) 0x04 : (byte) 0x01)
                .put(address)
                .putShort((short) bindPort);
        out.write(reply.array());
        out.flush();
    }

    /**
     * Reads the target host of a SOCKS5 request.
     *
     * @param type ATYP byte: 0x01 IPv4, 0x03 length-prefixed domain name, 0x04 IPv6
     * @param in   stream positioned at DST.ADDR
     * @return the host as a literal address or domain name, or {@code null} for an unknown ATYP
     *         (in which case nothing is consumed)
     * @throws IOException if the stream ends before the address is complete
     */
    private String getHost(byte type, InputStream in) throws IOException {
        switch (type) {
            case 0x01: // IPv4
                return InetAddress.getByAddress(readFully(in, 4)).getHostAddress();
            case 0x03: // domain name, prefixed by a one-byte length
                return new String(readFully(in, readUnsignedByte(in)), StandardCharsets.UTF_8);
            case 0x04: // IPv6
                return InetAddress.getByAddress(readFully(in, 16)).getHostAddress();
            default:
                return null;
        }
    }

    /** Reads one byte (0-255), throwing {@link EOFException} instead of returning -1. */
    private static int readUnsignedByte(InputStream in) throws IOException {
        return new DataInputStream(in).readUnsignedByte();
    }

    /** Reads exactly {@code length} bytes; a bare {@code read(byte[])} may return fewer. DataInputStream does not buffer ahead. */
    private static byte[] readFully(InputStream in, int length) throws IOException {
        byte[] buffer = new byte[length];
        new DataInputStream(in).readFully(buffer);
        return buffer;
    }

    /** Reads a big-endian unsigned 16-bit port number. */
    private static int readPort(InputStream in) throws IOException {
        byte[] b = readFully(in, 2);
        return ((b[0] & 0xFF) << 8) | (b[1] & 0xFF);
    }

    /** Reads bytes up to (and consuming, but excluding) the next NUL, decoding them as UTF-8. */
    private static String readNulTerminated(InputStream in) throws IOException {
        ByteArrayOutputStream field = new ByteArrayOutputStream();
        int b;
        while ((b = readUnsignedByte(in)) != 0) {
            if (field.size() >= MAX_NUL_TERMINATED_FIELD) {
                throw new IOException("SOCKS4 field longer than " + MAX_NUL_TERMINATED_FIELD + " bytes");
            }
            field.write(b);
        }
        return new String(field.toByteArray(), StandardCharsets.UTF_8);
    }

    /** Returns whether {@code wanted} occurs in {@code values}. */
    private static boolean contains(byte[] values, byte wanted) {
        for (byte value : values) {
            if (value == wanted) {
                return true;
            }
        }
        return false;
    }

    /**
     * Closes a socket quietly; {@code null} is ignored and close failures are deliberately discarded.
     *
     * @createTime 2014年12月14日 下午7:50:56
     * @param closeable socket to close, may be {@code null}
     */
    protected static final void closeIo(Socket closeable) {
        if (null != closeable) {
            try {
                closeable.close();
            } catch (IOException ignored) {
                // Best effort: nothing useful can be done if close fails.
            }
        }
    }

    /**
     * Closes any {@link Closeable} quietly; {@code null} is ignored and close failures are deliberately discarded.
     *
     * @createTime 2014年12月14日 下午7:50:56
     * @param closeable resource to close, may be {@code null}
     */
    protected static final void closeIo(Closeable closeable) {
        if (null != closeable) {
            try {
                closeable.close();
            } catch (IOException ignored) {
                // Best effort: nothing useful can be done if close fails.
            }
        }
    }

    /**
     * Copies {@code in} to {@code out} on a new thread (one direction of a TCP relay).
     *
     * <p>Every chunk is optionally mirrored into {@code cache}. The copy stops at EOF or on the first
     * exception, which is deliberately swallowed. {@code latch} is then counted down, if non-null.
     * The method does not close either stream. {@link #run()} uses its own half-close-aware relay; this
     * method is kept for existing callers.
     *
     * @createTime 2014年12月13日 下午11:06:47
     * @param latch counted down once when copying ends, may be {@code null}
     * @param in    source stream
     * @param out   destination stream
     * @param cache optional sink receiving a copy of all transferred bytes, may be {@code null}
     */
    protected static final void transfer(final CountDownLatch latch, final InputStream in, final OutputStream out,
                                         final OutputStream cache) {
        new Thread() {
            @Override
            public void run() {
                byte[] bytes = new byte[1024];
                int n;
                try {
                    while ((n = in.read(bytes)) > 0) {
                        out.write(bytes, 0, n);
                        out.flush();
                        if (null != cache) {
                            synchronized (cache) {
                                cache.write(bytes, 0, n);
                            }
                        }
                    }
                } catch (Exception ignored) {
                    // Best effort: a failure here just means one side of the connection went away.
                } finally {
                    if (null != latch) {
                        latch.countDown();
                    }
                }
            }
        }.start();
    }

}
